
feat: Add SPSA optimization method (Issue #357) #1712

Open
Paramveersingh-S wants to merge 4 commits into google-deepmind:main from Paramveersingh-S:main

Conversation

@Paramveersingh-S

Addresses #357

Description

This PR implements the Simultaneous Perturbation Stochastic Approximation (SPSA) gradient estimator to address the open feature request #357.

Rather than implementing it as a stateful optax optimizer, it is implemented as a composable gradient estimator (optax.contrib.spsa_estimator). This aligns best with JAX's functional paradigm, allowing users to pass the resulting grad_fn directly into any existing optax optimizer (SGD, Adam, etc.) and optax.chain. Standard polynomial schedules for learning rate and perturbation scaling are also provided.
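A minimal sketch of what such a composable estimator could look like, written independently of this PR. The name `spsa_grad`, its signature, and the fixed perturbation size `c` are assumptions for illustration and are not the API of `optax.contrib.spsa_estimator`. It uses Rademacher (±1) perturbations and two loss evaluations per estimate.

```python
import jax
import jax.numpy as jnp


def spsa_grad(loss_fn, params, key, c=0.1):
    """Hypothetical SPSA estimate of the gradient of loss_fn at params."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    # One independent +/-1 (Rademacher) perturbation per parameter entry.
    delta_leaves = [
        jax.random.rademacher(k, leaf.shape, dtype=leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    delta = jax.tree_util.tree_unflatten(treedef, delta_leaves)

    # Two loss evaluations, regardless of the number of parameters.
    # (jax.tree_util.tree_map is the same function as jax.tree.map.)
    plus = jax.tree_util.tree_map(lambda p, d: p + c * d, params, delta)
    minus = jax.tree_util.tree_map(lambda p, d: p - c * d, params, delta)
    scale = (loss_fn(plus) - loss_fn(minus)) / (2.0 * c)

    # For +/-1 entries, dividing by delta_i equals multiplying by delta_i.
    return jax.tree_util.tree_map(lambda d: scale * d, delta)


def loss(params):
    return jnp.sum(params["w"] ** 2) + jnp.sum(params["b"] ** 2)


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array([0.5])}
grads = spsa_grad(loss, params, jax.random.PRNGKey(0))

# Plain gradient-descent step; the same pytree of estimates could instead
# be handed to an optimizer's update function.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Because the estimate has the same pytree structure as `params`, it can stand in wherever a gradient from `jax.grad` would be used.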

Verification

I have added rigorous unit tests in tests/contrib/spsa_test.py utilizing chex.all_variants:

  1. Unbiasedness Test: Verified that, averaged over 10,000 samples, the SPSA gradient estimate matches the true gradient of a multivariable quadratic function to within the test tolerance.
  2. Optimizer Integration Test: Verified seamless integration with optax.sgd minimizing a noisy objective over 50 steps.
  3. Successfully tested compilation under jax.jit and jax.vmap.
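The unbiasedness check can be sketched roughly as follows. This is not the code in tests/contrib/spsa_test.py; the quadratic, the helper `spsa_sample`, and the tolerance are assumptions chosen for illustration. The sketch also exercises `jax.jit` and `jax.vmap` by drawing all samples in one batched call.

```python
import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])


def loss(x):
    # Quadratic test function; its exact gradient is A * x.
    return 0.5 * jnp.sum(A * x ** 2)


def spsa_sample(key, x, c=0.1):
    # Hypothetical single-sample SPSA estimate for a flat parameter vector.
    delta = jax.random.rademacher(key, x.shape, dtype=x.dtype)
    scale = (loss(x + c * delta) - loss(x - c * delta)) / (2.0 * c)
    return scale * delta


x = jnp.array([1.0, -2.0, 0.5])
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)

# vmap over keys draws all samples in one call; jit compiles the batch.
batched = jax.jit(jax.vmap(spsa_sample, in_axes=(0, None)))
samples = batched(keys, x)
mean_estimate = samples.mean(axis=0)
true_grad = jax.grad(loss)(x)
```

Each single sample is noisy (its off-diagonal terms only cancel in expectation), so the comparison between `mean_estimate` and `true_grad` needs a statistical tolerance rather than exact equality.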

Note on Author Verification: Since SPSA is a classical algorithm by Spall (1998) and not a recent paper, I did not directly email the author. However, the mathematical unbiasedness tests confirm its correctness.

@google-cla

google-cla Bot commented Jun 23, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Paramveersingh-S force-pushed the main branch 2 times, most recently from d304c4b to 1bc8449 (June 23, 2026 06:51)
Comment thread optax/contrib/_spsa.py Outdated
def spsa_standard_schedule(
    init_value: float,
    decay_rate: float,
    offset: float = 0.0,

If a user instantiates this schedule with the defaults, the very first step (count=0) will result in a ZeroDivisionError (or yield inf in JAX). Change the default offset to something mathematically stable, or at least enforce that offset > 0 if count starts at 0.
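A small sketch of the failure mode, assuming the schedule has the form init_value / (count + offset) ** decay_rate. That form is inferred from the values asserted in tests/contrib/spsa_test.py later in this thread, not taken from the PR source, and `polynomial_decay` is a hypothetical stand-in for `spsa_standard_schedule`.

```python
import jax.numpy as jnp


def polynomial_decay(init_value, decay_rate, offset):
    # Assumed form: init_value / (count + offset) ** decay_rate.
    def schedule(count):
        step = jnp.asarray(count, dtype=jnp.float32) + offset
        return init_value / step ** decay_rate
    return schedule


unsafe = polynomial_decay(init_value=1.0, decay_rate=0.5, offset=0.0)
safe = polynomial_decay(init_value=1.0, decay_rate=0.5, offset=1.0)

first_unsafe = unsafe(0)  # division by zero: JAX returns inf, no exception
first_safe = safe(0)      # 1.0 / 1.0 ** 0.5
```

With a zero offset the first value is already non-finite, and any update scaled by it would be non-finite as well.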

Comment thread optax/contrib/_spsa.py Outdated
Comment on lines +112 to +114
grad_estimate = jax.tree.map(
    lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta
)

You are recalculating the scalar coefficient for every single leaf in the PyTree. y_plus, y_minus, and c are all scalars. Calculate this coefficient once outside the tree map, then apply only the multiplication inside it.

Do this instead:

scalar_diff = (y_plus - y_minus) / (2.0 * c)
grad_estimate = jax.tree.map(lambda d: scalar_diff * d, delta)

Comment thread optax/contrib/_spsa.py Outdated
# equivalent
# to multiplying by delta_i. We multiply for numerical stability.
grad_estimate = jax.tree.map(
    lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta

We need numerical safety here. If c decays to exactly 0 or gets sufficiently small, this division will explode.

Comment thread tests/contrib/spsa_test.py Outdated
Comment on lines +88 to +89
self.assertAlmostEqual(val_0, 1.0 / (10.0**0.5))
self.assertAlmostEqual(val_10, 1.0 / (20.0**0.5))

Stick to np.testing.assert_allclose here instead of assertAlmostEqual.
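For context, `assertAlmostEqual` rounds the absolute difference to 7 decimal places by default, while `np.testing.assert_allclose` takes an explicit relative tolerance and also works on arrays. A minimal sketch of the suggested style, using a float32 NumPy value as an assumed stand-in for the schedule output:

```python
import numpy as np

# float32 stand-in for a schedule value computed by JAX at default precision.
val_0 = np.float32(1.0) / np.sqrt(np.float32(10.0))
expected = 1.0 / (10.0 ** 0.5)

# Explicit relative tolerance; the default rtol=1e-7 is at the edge of
# float32 precision.
np.testing.assert_allclose(val_0, expected, rtol=1e-6)
```

Stating `rtol` explicitly makes the precision the test relies on visible at the call site.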

@Paramveersingh-S

Author

Thanks for the thorough review, @servusdei2018! I've pushed a new commit addressing all of your points:

  1. Schedule stability: I've updated spsa_standard_schedule to default to offset = 1.0 so that count=0 is mathematically stable right out of the box and avoids yielding inf.
  2. Optimized gradient calculation: The scalar division (y_plus - y_minus) / (2.0 * safe_c) is now hoisted and calculated just once before the jax.tree.map, applying only the multiplication across the PyTree leaves.
  3. Numerical safety: I've protected the perturbation scale division using jnp.maximum(c, jnp.finfo(jnp.result_type(c)).eps) so that the division cannot produce inf or nan as c decays toward 0.
  4. Testing Standards: I've updated the assertions in spsa_test.py from assertAlmostEqual to np.testing.assert_allclose.
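Points 2 and 3 together might look roughly like the sketch below. The helper name `scaled_estimate` and its arguments are hypothetical; only the clamp expression and the hoisted scalar follow what is described above.

```python
import jax
import jax.numpy as jnp


def scaled_estimate(y_plus, y_minus, c, delta):
    # Clamp c from below by machine epsilon for its dtype.
    eps = jnp.finfo(jnp.result_type(c)).eps
    safe_c = jnp.maximum(c, eps)
    # Scalar coefficient computed once, outside the tree map.
    scalar_diff = (y_plus - y_minus) / (2.0 * safe_c)
    # (jax.tree_util.tree_map is the same function as jax.tree.map.)
    return jax.tree_util.tree_map(lambda d: scalar_diff * d, delta)


delta = {"w": jnp.array([1.0, -1.0]), "b": jnp.array([-1.0])}
estimate = scaled_estimate(1.0, 0.5, 0.0, delta)  # c has decayed to zero
```

The clamp keeps the result finite, although the estimate can still be very large when c is close to eps.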

All tests pass perfectly locally. Let me know if everything looks good on your end!
